/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package w10gan;

import lombok.EqualsAndHashCode;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Wasserstein critic loss with a gradient penalty term (WGAN-GP style). The Wasserstein loss
 * approximates the Wasserstein distance, also known as the earth mover's distance, and the
 * penalty term discourages the critic's gradient norm from straying from 1.
 *
 * This is not a general-purpose loss function; it is intended for use as a GAN discriminator
 * (critic) loss.
 *
 * When using it in a discriminator, use a label of 1 for real and -1 for generated samples,
 * instead of the 1 and 0 used in standard GANs.
 *
 * Wasserstein loss as described in
 * <a href="https://papers.nips.cc/paper/5679-learning-with-a-wasserstein-loss.pdf">Learning with a Wasserstein Loss</a>;
 * the gradient penalty follows
 * <a href="https://arxiv.org/abs/1704.00028">Improved Training of Wasserstein GANs</a>.
 *
 * @author Ryan Nett
 */
@EqualsAndHashCode(callSuper = false)
public class Loss10GradientPenalty implements ILossFunction {

    // Interpolated ("averaged") samples between real and generated images.
    // NOTE(review): stored but never read by the current implementation — the penalty is instead
    // computed from the last third of the minibatch, which callers are expected to fill with
    // interpolated samples. Retained so the constructor signature stays compatible.
    private INDArray averaged_samples;

    // Gradient penalty weight (lambda in the WGAN-GP paper, which recommends lambda = 10).
    private int lam = 10;

    /**
     * @param interImg interpolated samples between real and generated images; currently unused,
     *                 see {@code averaged_samples}
     */
    public Loss10GradientPenalty(INDArray interImg) {
        this.averaged_samples = interImg;
    }

    /**
     * Computes the weighted gradient penalty {@code lam * (1 - ||grad||)^2} per interpolated sample.
     *
     * @param interOut     pre-activation critic output for the interpolated third of the batch
     * @param interLabels  labels for the interpolated third of the batch
     * @param activationFn the critic's output activation function
     * @return per-sample penalty with the batch dimension preserved
     */
    private INDArray gradientPenalty(INDArray interOut, INDArray interLabels, IActivation activationFn) {
        INDArray dLda = interLabels.div(interLabels.size(1));
        INDArray gradients = activationFn.backprop(interOut, dLda).getFirst();

        // Squared Euclidean norm per sample: square elementwise, then sum over all non-batch
        // dimensions. (The previous implementation used cumsum; for the [n, 1] labels this class
        // is written for the two coincide, and a plain sum matches the documented intent.)
        INDArray gradientsSqr = gradients.mul(gradients); // non-mutating: don't clobber backprop output
        INDArray sqrSum = gradientsSqr;
        int rank = gradientsSqr.rank();
        if (rank > 1) {
            int[] dims = new int[rank - 1];
            for (int d = 1; d < rank; d++) {
                dims[d - 1] = d;
            }
            sqrSum = gradientsSqr.sum(true, dims); // keepDims so the batch shape is preserved
        }
        INDArray gradientL2Norm = Transforms.sqrt(sqrSum);

        // lam * (1 - ||grad||)^2 per sample. Fix: lam was declared as the penalty weight but was
        // never applied in the previous implementation.
        INDArray oneMinusNorm = Nd4j.ones(gradientL2Norm.shape()).subi(gradientL2Norm);
        return oneMinusNorm.muli(oneMinusNorm).muli(lam);
    }

    /**
     * Per-sample loss for a minibatch stacked along dimension 0 as [fake; real; interpolated]
     * in three equal thirds: {@code f(fake) - f(real) + lam * (1 - ||grad||)^2}.
     * The per-third score is replicated three times so the result matches the label batch size.
     */
    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        long rows = labels.shape()[0]; // e.g. [60, 1]: three equal thirds stacked along dim 0

        INDArray fakeOut = preOutput.get(NDArrayIndex.interval(0, rows / 3));
        INDArray realOut = preOutput.get(NDArrayIndex.interval(rows / 3, rows * 2 / 3));
        INDArray interOut = preOutput.get(NDArrayIndex.interval(rows * 2 / 3, rows));

        // dup() before activation: getActivation may work in place on its argument.
        INDArray fakeScore = activationFn.getActivation(fakeOut.dup(), true);
        INDArray realScore = activationFn.getActivation(realOut.dup(), true);

        INDArray interLabels = labels.get(NDArrayIndex.interval(rows * 2 / 3, rows));
        INDArray penalty = gradientPenalty(interOut, interLabels, activationFn);

        INDArray score = fakeScore.sub(realScore).add(penalty);
        return Nd4j.vstack(score, score, score);
    }

    /**
     * Scalar loss over the minibatch: sum (or mean, if {@code average}) of the per-sample scores.
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
                               boolean average) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);

        double score = scoreArr.mean(1).sumNumber().doubleValue();

        if (average) {
            score /= scoreArr.size(0);
        }
        return score;
    }

    /**
     * NOTE(review): returns a single-element array (the overall mean), not one score per example
     * as most ILossFunction implementations do — preserved from the original implementation;
     * confirm against how callers consume this before changing.
     */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        return Nd4j.expandDims(scoreArr.mean(), 1);
    }

    /**
     * Gradient of the loss w.r.t. the pre-activation output, replicated across the three thirds
     * of the stacked batch.
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        long rows = labels.shape()[0];

        INDArray dLda = labels.div(labels.size(1));
        if (mask != null && LossUtil.isPerOutputMasking(dLda, mask)) {
            LossUtil.applyMask(labels, mask);
        }

        INDArray interOut = preOutput.get(NDArrayIndex.interval(rows * 2 / 3, rows));
        INDArray interLabels = labels.get(NDArrayIndex.interval(rows * 2 / 3, rows));
        INDArray dLdaInter = interLabels.div(interLabels.size(1));

        INDArray penalty = gradientPenalty(interOut, interLabels, activationFn);

        // NOTE(review): this backprops through the activation using the penalty VALUE as input,
        // which is not the analytic dPenalty/dPreOutput; preserved from the original
        // implementation — verify against training behavior before changing.
        INDArray grad = activationFn.backprop(penalty, dLdaInter).getFirst();

        return Nd4j.vstack(grad, grad, grad);
    }

    /** Convenience combination of {@link #computeScore} and {@link #computeGradient}. */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
            INDArray mask, boolean average) {
        return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average),
                computeGradient(labels, preOutput, activationFn, mask));
    }

    @Override
    public String name() {
        return toString();
    }

    /** Fix: previously returned "Loss7Gradient()", which did not match this class. */
    @Override
    public String toString() {
        return "Loss10GradientPenalty()";
    }
}
